{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0d86b3a64bcb"
      },
      "outputs": [],
      "source": [
        "# Copyright 2021 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "53271d481112"
      },
      "source": [
        "# PyTorch Lightning Training\n",
        "This notebook trains a model to predict whether the given sonar signals are bouncing off a metal cylinder or off a cylindrical rock from [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Connectionist+Bench+%28Sonar%2C+Mines+vs.+Rocks%29).\n",
        "\n",
        "This notebook is derived from the [PyTorch sample](https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/pytorch/TrainingAndPredictionWithPyTorch.ipynb). It demonstrates how to perform the same task using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), a lightweight wrapper around PyTorch.\n",
        "\n",
        "The notebook is intended to run within [AI Platform Notebooks](https://cloud.google.com/ai-platform-notebooks). The model will be trained within the notebook instance VM, optionally attached to GPUs or TPUs. With the following link, you can directly [Open in AI Platform Notebooks](https://console.cloud.google.com/ai-platform/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/ai-platform-samples/master/notebooks/samples/pytorch/lightning/TrainingAndPredictionWithPyTorchLightning.ipynb). \n",
        "\n",
        "### Dataset\n",
        "The Sonar Signals dataset that this sample uses for training is provided by the UC Irvine Machine Learning Repository. Google has hosted the data on a public GCS bucket `gs://cloud-samples-data/ai-platform/sonar/sonar.all-data`.\n",
        "\n",
        "* `sonar.all-data` is split for both training and evaluation\n",
        "\n",
        "Note: Your typical development process with your own data would require you to upload your data to GCS so that you can access that data from inside your notebook. However, in this case, Google has put the data on GCS to avoid the steps of having you download the data from UC Irvine and then upload the data to GCS.\n",
        "\n",
        "### Disclaimer\n",
        "This dataset is provided by a third party. Google provides no representation, warranty, or other guarantees about the validity or any other aspects of this dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "36c617935824"
      },
      "source": [
        "## (Optional) TPU configuration\n",
        "\n",
        "To use [Cloud TPUs](https://cloud.google.com/tpu), first create a [TPU node](https://cloud.google.com/tpu/docs/creating-deleting-tpus#setup_TPU_only). Set the **TPU software version** to a matching PyTorch version (e.g. `pytorch-1.7`) and the **Network** to the same network used for your notebook instance (e.g. `datalab-network`).\n",
        "\n",
        "Uncomment this section only if you are using TPUs. Note that you must be running this notebook on an [XLA](https://github.com/pytorch/xla) image such as [pytorch-xla.1-7](gcr.io/deeplearning-platform-release/pytorch-xla.1-7) for PyTorch to connect to Cloud TPUs. To use an XLA image, you can create a new notebook instance with the **Environment** set to `Custom container` and the **Docker container image** set to the XLA image location.\n",
        "\n",
        "If you need a quota increase for Cloud TPUs, please review the [Cloud TPU Quota Policy](https://cloud.google.com/tpu/docs/quota) for more details."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5df45a7a4dbb"
      },
      "source": [
        "### Review TPU configuration\n",
        "\n",
        "Run the gcloud command to review the available TPUs for the one you wish to use.\n",
        "Make note of the IP address (from NETWORK_ENDPOINT, without the port), and the # of TPU cores (derived from ACCELERATOR_TYPE). An ACCELERATOR_TYPE of v3-8 will indicate 8 TPU cores, for example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2516adf400fd"
      },
      "outputs": [],
      "source": [
        "# !gcloud compute tpus list --zone=YOUR_ZONE_HERE_SUCH_AS_us-central1-b"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9fb190d9c4c4"
      },
      "source": [
        "### Update TPU configuration\n",
        "\n",
        "Update the IP address and cores variables here"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8cea12b76f3d"
      },
      "outputs": [],
      "source": [
        "# tpu_ip_address='10.1.2.3'\n",
        "# tpu_cores=8"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "31977845e279"
      },
      "source": [
        "### Set TPU environment variables"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f309540500ee"
      },
      "outputs": [],
      "source": [
        "# # TPU configuration\n",
        "# %env XRT_TPU_CONFIG=tpu_worker;0;$tpu_ip_address:8470\n",
        "\n",
        "# # Use bfloat16\n",
        "# %env XLA_USE_BF16=1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8852b462f896"
      },
      "source": [
        "## Install and import packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4f7cce0e8639"
      },
      "outputs": [],
      "source": [
        "!pip install -U pytorch-lightning --quiet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7ffdeca31c4d"
      },
      "outputs": [],
      "source": [
        "from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils\n",
        "if XLADeviceUtils.tpu_device_exists():\n",
        "    import torch_xla  # noqa: F401\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "from torch.utils.data import DataLoader, Dataset, random_split\n",
        "\n",
        "import pandas as pd\n",
        "from google.cloud import storage\n",
        "\n",
        "from pytorch_lightning.core import LightningModule, LightningDataModule\n",
        "from pytorch_lightning.metrics.functional import accuracy\n",
        "from pytorch_lightning.trainer.trainer import Trainer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fd71f9fee1fe"
      },
      "source": [
        "## Environment configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "83e17af124f1"
      },
      "outputs": [],
      "source": [
        "_ = !nproc\n",
        "tpu_cores = tpu_cores if 'tpu_cores' in vars() else 0\n",
        "num_cpus = int(_[0])\n",
        "num_gpus = torch.cuda.device_count()\n",
        "device = torch.device('cuda') if num_gpus else 'cpu'\n",
        "\n",
        "print(f'Device: {device}')\n",
        "print(f'CPUs: {num_cpus}')\n",
        "print(f'GPUs: {num_gpus}')\n",
        "print(f'TPUs: {tpu_cores}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2eca6dd6b2b5"
      },
      "source": [
        "## Download data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7de884903ae9"
      },
      "outputs": [],
      "source": [
        "# Public bucket holding data for samples\n",
        "BUCKET = 'cloud-samples-data'\n",
        "\n",
        "# Path to the directory inside the public bucket containing the sample data\n",
        "BUCKET_PATH = 'ai-platform/sonar/'\n",
        "\n",
        "# Sample data file\n",
        "FILE = 'sonar.all-data'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ca9d7d6ac0c9"
      },
      "outputs": [],
      "source": [
        "bucket = storage.Client().bucket(BUCKET)\n",
        "\n",
        "blob = bucket.blob(BUCKET_PATH + FILE)\n",
        "\n",
        "blob.download_to_filename(FILE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e67a891c1178"
      },
      "source": [
        "## Define the PyTorch [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "af029dead3d0"
      },
      "outputs": [],
      "source": [
        "class SonarDataset(Dataset):\n",
        "    def __init__(self, csv_file):\n",
        "        self.dataframe = pd.read_csv(csv_file, header=None)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.dataframe)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        # When iterating through the dataset get the features and targets\n",
        "        features = self.dataframe.iloc[idx, :-1].values.astype(dtype='float64')\n",
        "\n",
        "        # Convert the targets to binary values:\n",
        "        # R = rock --> 0\n",
        "        # M = mine --> 1\n",
        "        target = self.dataframe.iloc[idx, -1:].values\n",
        "        if target[0] == 'R':\n",
        "            target[0] = 0\n",
        "        elif target[0] == 'M':\n",
        "            target[0] = 1\n",
        "        target = target.astype(dtype='float64')\n",
        "\n",
        "        # Load the data as a tensor\n",
        "        data = {'features': torch.from_numpy(features),\n",
        "                'target': target}\n",
        "        return data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3aa68362fa7c"
      },
      "source": [
        "## Define a data processing module\n",
        "\n",
        "In this step, you will create a custom data module that extends [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html) to encapsulate the data processing steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "49910c2f476b"
      },
      "outputs": [],
      "source": [
        "class SonarDataModule(LightningDataModule):\n",
        "\n",
        "    def __init__(self, bucket=BUCKET, bucket_path=BUCKET_PATH, file=FILE, batch_size=32, num_workers=0):\n",
        "        super().__init__()\n",
        "\n",
        "        self.batch_size = batch_size\n",
        "        self.num_workers = num_workers\n",
        "        self.bucket = bucket\n",
        "        self.bucket_path = bucket_path\n",
        "        self.file = file\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # Public bucket holding the data\n",
        "        bucket = storage.Client().bucket(self.bucket)\n",
        "\n",
        "        # Path to the data inside the public bucket\n",
        "        blob = bucket.blob(self.bucket_path + self.file)\n",
        "\n",
        "        # Download the data\n",
        "        blob.download_to_filename(self.file)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "        # Load the data\n",
        "        sonar_dataset = SonarDataset(self.file)\n",
        "\n",
        "        # Create indices for the split\n",
        "        dataset_size = len(sonar_dataset)\n",
        "        test_size = int(0.2 * dataset_size)  # Use a test_split of 0.2\n",
        "        val_size = int(0.2 * dataset_size)  # Use a test_split of 0.2\n",
        "        train_size = dataset_size - test_size - val_size\n",
        "\n",
        "        # Assign train/test/val datasets for use in dataloaders\n",
        "        self.sonar_train, self.sonar_val, self.sonar_test = random_split(sonar_dataset, [train_size, val_size, test_size])\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.sonar_train, batch_size=self.batch_size, num_workers=self.num_workers)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.sonar_val, batch_size=self.batch_size, num_workers=self.num_workers)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.sonar_test, batch_size=self.batch_size, num_workers=self.num_workers)\n",
        "\n",
        "\n",
        "dm = SonarDataModule(num_workers=num_cpus)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3e5bedd52c06"
      },
      "source": [
        "## Define a model\n",
        "\n",
        "Next, you will create a module that extends [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html). This module includes your model code and organizes steps of the model-building process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "81e6c6ced7b3"
      },
      "outputs": [],
      "source": [
        "class SonarModel(LightningModule):\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        # Define PyTorch model\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Linear(60, 60),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(p=0.2),\n",
        "            nn.Linear(60, 30),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(p=0.2),\n",
        "            nn.Linear(30, 1),\n",
        "            nn.Sigmoid()\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x.float())\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        x, y = batch['features'].float(), batch['target'].float()\n",
        "        y_hat = self(x)\n",
        "\n",
        "        loss = F.binary_cross_entropy(y_hat, y)\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch['features'].float(), batch['target'].float()\n",
        "        y_hat = self(x)\n",
        "\n",
        "        loss = F.binary_cross_entropy(y_hat, y)\n",
        "\n",
        "        # Binarize the output\n",
        "        y_hat_binary = y_hat.round()\n",
        "        acc = accuracy(y_hat_binary, y.int())\n",
        "\n",
        "        # Log metrics for TensorBoard\n",
        "        self.log('val_loss', loss, prog_bar=True)\n",
        "        self.log('val_acc', acc, prog_bar=True)\n",
        "\n",
        "        return loss\n",
        "\n",
        "    def test_step(self, batch, batch_idx):\n",
        "        # Reuse validation step\n",
        "        return self.validation_step(batch, batch_idx)\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.5, nesterov=False)\n",
        "\n",
        "\n",
        "model = SonarModel()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6e5b46ae750d"
      },
      "source": [
        "## Train and evaluate the model\n",
        "\n",
        "Finally, you will create and use a [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html) to build the model and evaluate its accuracy.\n",
        "\n",
        "The trainer is initialized with an [accelerator](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#accelerator), with different options depending on your environment:\n",
        "* TPUs support only [ddp](https://pytorch-lightning.readthedocs.io/en/stable/tpu.html#distributed-backend-with-tpu), or distributed data parallel, and so `accelerator` cannot be specified. \n",
        "* GPUs support a variety of [distributed modes](https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#distributed-modes). In this notebook, we are using [dp](https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#data-parallel) for multiple GPUs on 1 machine.\n",
        "* CPUs can support [ddp_cpu](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-flags) for multi-node CPU training. For multiple CPUs on one node, there is no speed increase from using this accelerator, and so the default of `None` is used in this notebook. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5125f2b9f58f"
      },
      "outputs": [],
      "source": [
        "epochs = 100\n",
        "\n",
        "if tpu_cores:\n",
        "    trainer = Trainer(tpu_cores=tpu_cores, max_epochs=epochs)\n",
        "elif num_gpus:\n",
        "    trainer = Trainer(gpus=num_gpus, accelerator='dp', max_epochs=epochs)\n",
        "else:\n",
        "    trainer = Trainer(max_epochs=epochs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0f5dfe187909"
      },
      "outputs": [],
      "source": [
        "trainer.fit(model, dm)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f2dd4e2607c9"
      },
      "outputs": [],
      "source": [
        "trainer.test(datamodule=dm)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0444088c3027"
      },
      "source": [
        "## Save and load a trained model\n",
        "\n",
        "The following steps aren't required, but are shown for use in a production environment.\n",
        "\n",
        "First, we'll export the model to a file. Then, we'll load the model file (which isn't required in a notebook, because we already have a trained model). Finally, we'll set the model to evaluation mode (rather than train mode) for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eef117294980"
      },
      "outputs": [],
      "source": [
        "torch.save(model.state_dict(), 'model.pt')\n",
        "\n",
        "model.load_state_dict(torch.load('model.pt'))\n",
        "\n",
        "model.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "28631cca611b"
      },
      "source": [
        "## Predict with the model \n",
        "\n",
        "Finally, let's illustrate model inference, with set values as inputs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f08338e281ac"
      },
      "outputs": [],
      "source": [
        "rock_feature = torch.tensor([[3.6800e-02, 4.0300e-02, 3.1700e-02, 2.9300e-02, 8.2000e-02, 1.3420e-01,\n",
        "                              1.1610e-01, 6.6300e-02, 1.5500e-02, 5.0600e-02, 9.0600e-02, 2.5450e-01,\n",
        "                              1.4640e-01, 1.2720e-01, 1.2230e-01, 1.6690e-01, 1.4240e-01, 1.2850e-01,\n",
        "                              1.8570e-01, 1.1360e-01, 2.0690e-01, 2.1900e-02, 2.4000e-01, 2.5470e-01,\n",
        "                              2.4000e-02, 1.9230e-01, 4.7530e-01, 7.0030e-01, 6.8250e-01, 6.4430e-01,\n",
        "                              7.0630e-01, 5.3730e-01, 6.6010e-01, 8.7080e-01, 9.5180e-01, 9.6050e-01,\n",
        "                              7.7120e-01, 6.7720e-01, 6.4310e-01, 6.7200e-01, 6.0350e-01, 5.1550e-01,\n",
        "                              3.8020e-01, 2.2780e-01, 1.5220e-01, 8.0100e-02, 8.0400e-02, 7.5200e-02,\n",
        "                              5.6600e-02, 1.7500e-02, 5.8000e-03, 9.1000e-03, 1.6000e-02, 1.6000e-02,\n",
        "                              8.1000e-03, 7.0000e-03, 1.3500e-02, 6.7000e-03, 7.8000e-03, 6.8000e-03]], dtype=torch.float64, device=device)\n",
        "rock_prediction = model(rock_feature)\n",
        "\n",
        "mine_feature = torch.tensor([[5.9900e-02, 4.7400e-02, 4.9800e-02, 3.8700e-02, 1.0260e-01, 7.7300e-02,\n",
        "                              8.5300e-02, 4.4700e-02, 1.0940e-01, 3.5100e-02, 1.5820e-01, 2.0230e-01,\n",
        "                              2.2680e-01, 2.8290e-01, 3.8190e-01, 4.6650e-01, 6.6870e-01, 8.6470e-01,\n",
        "                              9.3610e-01, 9.3670e-01, 9.1440e-01, 9.1620e-01, 9.3110e-01, 8.6040e-01,\n",
        "                              7.3270e-01, 5.7630e-01, 4.1620e-01, 4.1130e-01, 4.1460e-01, 3.1490e-01,\n",
        "                              2.9360e-01, 3.1690e-01, 3.1490e-01, 4.1320e-01, 3.9940e-01, 4.1950e-01,\n",
        "                              4.5320e-01, 4.4190e-01, 4.7370e-01, 3.4310e-01, 3.1940e-01, 3.3700e-01,\n",
        "                              2.4930e-01, 2.6500e-01, 1.7480e-01, 9.3200e-02, 5.3000e-02, 8.1000e-03,\n",
        "                              3.4200e-02, 1.3700e-02, 2.8000e-03, 1.3000e-03, 5.0000e-04, 2.2700e-02,\n",
        "                              2.0900e-02, 8.1000e-03, 1.1700e-02, 1.1400e-02, 1.1200e-02, 1.0000e-02]], dtype=torch.float64, device=device)\n",
        "mine_prediction = model(mine_feature)\n",
        "\n",
        "print('Result Values: (Rock: 0) - (Mine: 1)\\n')\n",
        "print(f'Rock Prediction:\\n\\t{\"Rock\" if rock_prediction <= 0.5 else \"Mine\"} - {rock_prediction.item()}')\n",
        "print(f'Mine Prediction:\\n\\t{\"Rock\" if mine_prediction <= 0.5 else \"Mine\"} - {mine_prediction.item()}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "TrainingAndPredictionWithPyTorchLightning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
